{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3 注意力评分函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 10.3.1 遮蔽Softmax操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def masked_softmax(X, valid_lens):\n",
    "    # X: 3D tensor, valid_lens: 1D or 2D tensor\n",
    "    if valid_lens is None:\n",
    "        return nn.functional.softmax(X, dim=-1)\n",
    "    else:\n",
    "        shape = X.shape\n",
    "        if valid_lens.dim() == 1:\n",
    "            valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n",
    "        else:\n",
    "            valid_lens = valid_lens.reshape(-1)\n",
    "        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)\n",
    "        return nn.functional.softmax(X.reshape(shape), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.3104, 0.6896, 0.0000, 0.0000],\n",
       "         [0.5721, 0.4279, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.4290, 0.3759, 0.1951, 0.0000],\n",
       "         [0.2796, 0.4167, 0.3037, 0.0000]]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.4323, 0.2811, 0.2866, 0.0000]],\n",
       "\n",
       "        [[0.4222, 0.5778, 0.0000, 0.0000],\n",
       "         [0.2169, 0.2019, 0.2159, 0.3653]]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 10.3.2 加性注意力"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdditiveAttention(nn.Module):\n",
    "    \"\"\"加性注意力\"\"\"\n",
    "    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):\n",
    "        super(AdditiveAttention, self).__init__(**kwargs)\n",
    "        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)\n",
    "        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)\n",
    "        self.w_v = nn.Linear(num_hiddens, 1, bias=False)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        queries, keys = self.W_q(queries), self.W_k(keys)\n",
    "        # 为便于求和采用维度扩展，扩展后的维度为：\n",
    "        # queries: (batch_size, query_size, 1, num_hidden)\n",
    "        # keys: (batch_size, 1, key_size, num_hidden)\n",
    "        # 采用广播方式进行求和\n",
    "        features = queries.unsqueeze(2) + keys.unsqueeze(1)\n",
    "        features = torch.tanh(features)\n",
    "        # 从评分中移除最后一个维度\n",
    "        # scores: (batch_size, query_size, 键值对的个数)\n",
    "        scores = self.w_v(features).squeeze(-1)\n",
    "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
    "        # values: (batch_size, 键值对的个数, 值的维度)\n",
    "        return torch.bmm(self.dropout(self.attention_weights), values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[-0.0660,  1.0947,  0.0104, -0.5949, -0.3940, -1.4790, -1.3675,\n",
       "           -1.3994, -0.0803,  0.3751, -1.7337, -1.0699, -0.2337,  0.5680,\n",
       "           -0.5434,  0.4628, -0.0040, -0.3612, -2.7803, -0.4253]],\n",
       " \n",
       "         [[ 2.9150,  0.4798,  1.3979, -1.0488, -1.8972,  1.3426,  0.8674,\n",
       "            0.1136, -0.7063, -0.6918,  0.0405,  0.7316, -0.1360,  1.7512,\n",
       "            0.3680,  0.2736, -1.1909,  0.7521,  0.3385, -0.6900]]]),\n",
       " tensor([[[1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.]],\n",
       " \n",
       "         [[1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.]]]),\n",
       " tensor([[[ 0.,  1.,  2.,  3.],\n",
       "          [ 4.,  5.,  6.,  7.],\n",
       "          [ 8.,  9., 10., 11.],\n",
       "          [12., 13., 14., 15.],\n",
       "          [16., 17., 18., 19.],\n",
       "          [20., 21., 22., 23.],\n",
       "          [24., 25., 26., 27.],\n",
       "          [28., 29., 30., 31.],\n",
       "          [32., 33., 34., 35.],\n",
       "          [36., 37., 38., 39.]],\n",
       " \n",
       "         [[ 0.,  1.,  2.,  3.],\n",
       "          [ 4.,  5.,  6.,  7.],\n",
       "          [ 8.,  9., 10., 11.],\n",
       "          [12., 13., 14., 15.],\n",
       "          [16., 17., 18., 19.],\n",
       "          [20., 21., 22., 23.],\n",
       "          [24., 25., 26., 27.],\n",
       "          [28., 29., 30., 31.],\n",
       "          [32., 33., 34., 35.],\n",
       "          [36., 37., 38., 39.]]]))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))\n",
    "values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)\n",
    "valid_lens = torch.tensor([2, 6])\n",
    "queries, keys, values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],\n",
       "\n",
       "        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)\n",
    "attention.eval()\n",
    "attention(queries, keys, values, valid_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"101.818906pt\" version=\"1.1\" viewBox=\"0 0 186.99575 101.818906\" width=\"186.99575pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 101.818906 \n",
       "L 186.99575 101.818906 \n",
       "L 186.99575 0 \n",
       "L -0 0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "L 34.240625 36.81 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g clip-path=\"url(#p35918af0c6)\">\n",
       "    <image height=\"23\" id=\"image20e4282e39\" transform=\"scale(1 -1)translate(0 -23)\" width=\"112\" x=\"34.240625\" xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAHAAAAAXCAYAAADTEcupAAAABHNCSVQICAgIfAhkiAAAAHhJREFUaIHt2LEJgEAQBdF/mlqfLRhag12Y2ZwtmAlebiZyHAPzClgWhk223Mf2REmScV57r/DZ0HsB/WNAOAPCGRDOgHAGhDMgnAHhDAhnQDgDwpUlU5Nf6H6dLcbqxQuEMyCcAeEMCGdAOAPCGRDOgHAGhDMgXAXNiAfzAcIdmQAAAABJRU5ErkJggg==\" y=\"-36.13\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"md9518025b6\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"39.820625\" xlink:href=\"#md9518025b6\" y=\"59.13\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(36.639375 73.728437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"95.620625\" xlink:href=\"#md9518025b6\" y=\"59.13\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-53\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(92.439375 73.728437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_3\">\n",
       "     <!-- Keys -->\n",
       "     <defs>\n",
       "      <path d=\"M 9.8125 72.90625 \n",
       "L 19.671875 72.90625 \n",
       "L 19.671875 42.09375 \n",
       "L 52.390625 72.90625 \n",
       "L 65.09375 72.90625 \n",
       "L 28.90625 38.921875 \n",
       "L 67.671875 0 \n",
       "L 54.6875 0 \n",
       "L 19.671875 35.109375 \n",
       "L 19.671875 0 \n",
       "L 9.8125 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-75\"/>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-101\"/>\n",
       "      <path d=\"M 32.171875 -5.078125 \n",
       "Q 28.375 -14.84375 24.75 -17.8125 \n",
       "Q 21.140625 -20.796875 15.09375 -20.796875 \n",
       "L 7.90625 -20.796875 \n",
       "L 7.90625 -13.28125 \n",
       "L 13.1875 -13.28125 \n",
       "Q 16.890625 -13.28125 18.9375 -11.515625 \n",
       "Q 21 -9.765625 23.484375 -3.21875 \n",
       "L 25.09375 0.875 \n",
       "L 2.984375 54.6875 \n",
       "L 12.5 54.6875 \n",
       "L 29.59375 11.921875 \n",
       "L 46.6875 54.6875 \n",
       "L 56.203125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-121\"/>\n",
       "      <path d=\"M 44.28125 53.078125 \n",
       "L 44.28125 44.578125 \n",
       "Q 40.484375 46.53125 36.375 47.5 \n",
       "Q 32.28125 48.484375 27.875 48.484375 \n",
       "Q 21.1875 48.484375 17.84375 46.4375 \n",
       "Q 14.5 44.390625 14.5 40.28125 \n",
       "Q 14.5 37.15625 16.890625 35.375 \n",
       "Q 19.28125 33.59375 26.515625 31.984375 \n",
       "L 29.59375 31.296875 \n",
       "Q 39.15625 29.25 43.1875 25.515625 \n",
       "Q 47.21875 21.78125 47.21875 15.09375 \n",
       "Q 47.21875 7.46875 41.1875 3.015625 \n",
       "Q 35.15625 -1.421875 24.609375 -1.421875 \n",
       "Q 20.21875 -1.421875 15.453125 -0.5625 \n",
       "Q 10.6875 0.296875 5.421875 2 \n",
       "L 5.421875 11.28125 \n",
       "Q 10.40625 8.6875 15.234375 7.390625 \n",
       "Q 20.0625 6.109375 24.8125 6.109375 \n",
       "Q 31.15625 6.109375 34.5625 8.28125 \n",
       "Q 37.984375 10.453125 37.984375 14.40625 \n",
       "Q 37.984375 18.0625 35.515625 20.015625 \n",
       "Q 33.0625 21.96875 24.703125 23.78125 \n",
       "L 21.578125 24.515625 \n",
       "Q 13.234375 26.265625 9.515625 29.90625 \n",
       "Q 5.8125 33.546875 5.8125 39.890625 \n",
       "Q 5.8125 47.609375 11.28125 51.796875 \n",
       "Q 16.75 56 26.8125 56 \n",
       "Q 31.78125 56 36.171875 55.265625 \n",
       "Q 40.578125 54.546875 44.28125 53.078125 \n",
       "z\n",
       "\" id=\"DejaVuSans-115\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(78.125 87.406562)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-75\"/>\n",
       "      <use x=\"65.498047\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"127.021484\" xlink:href=\"#DejaVuSans-121\"/>\n",
       "      <use x=\"186.201172\" xlink:href=\"#DejaVuSans-115\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m0332421316\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m0332421316\" y=\"42.39\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(20.878125 46.189219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m0332421316\" y=\"53.55\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 1 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 57.349219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- Queries -->\n",
       "     <defs>\n",
       "      <path d=\"M 39.40625 66.21875 \n",
       "Q 28.65625 66.21875 22.328125 58.203125 \n",
       "Q 16.015625 50.203125 16.015625 36.375 \n",
       "Q 16.015625 22.609375 22.328125 14.59375 \n",
       "Q 28.65625 6.59375 39.40625 6.59375 \n",
       "Q 50.140625 6.59375 56.421875 14.59375 \n",
       "Q 62.703125 22.609375 62.703125 36.375 \n",
       "Q 62.703125 50.203125 56.421875 58.203125 \n",
       "Q 50.140625 66.21875 39.40625 66.21875 \n",
       "z\n",
       "M 53.21875 1.3125 \n",
       "L 66.21875 -12.890625 \n",
       "L 54.296875 -12.890625 \n",
       "L 43.5 -1.21875 \n",
       "Q 41.890625 -1.3125 41.03125 -1.359375 \n",
       "Q 40.1875 -1.421875 39.40625 -1.421875 \n",
       "Q 24.03125 -1.421875 14.8125 8.859375 \n",
       "Q 5.609375 19.140625 5.609375 36.375 \n",
       "Q 5.609375 53.65625 14.8125 63.9375 \n",
       "Q 24.03125 74.21875 39.40625 74.21875 \n",
       "Q 54.734375 74.21875 63.90625 63.9375 \n",
       "Q 73.09375 53.65625 73.09375 36.375 \n",
       "Q 73.09375 23.6875 67.984375 14.640625 \n",
       "Q 62.890625 5.609375 53.21875 1.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-81\"/>\n",
       "      <path d=\"M 8.5 21.578125 \n",
       "L 8.5 54.6875 \n",
       "L 17.484375 54.6875 \n",
       "L 17.484375 21.921875 \n",
       "Q 17.484375 14.15625 20.5 10.265625 \n",
       "Q 23.53125 6.390625 29.59375 6.390625 \n",
       "Q 36.859375 6.390625 41.078125 11.03125 \n",
       "Q 45.3125 15.671875 45.3125 23.6875 \n",
       "L 45.3125 54.6875 \n",
       "L 54.296875 54.6875 \n",
       "L 54.296875 0 \n",
       "L 45.3125 0 \n",
       "L 45.3125 8.40625 \n",
       "Q 42.046875 3.421875 37.71875 1 \n",
       "Q 33.40625 -1.421875 27.6875 -1.421875 \n",
       "Q 18.265625 -1.421875 13.375 4.4375 \n",
       "Q 8.5 10.296875 8.5 21.578125 \n",
       "z\n",
       "M 31.109375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-117\"/>\n",
       "      <path d=\"M 41.109375 46.296875 \n",
       "Q 39.59375 47.171875 37.8125 47.578125 \n",
       "Q 36.03125 48 33.890625 48 \n",
       "Q 26.265625 48 22.1875 43.046875 \n",
       "Q 18.109375 38.09375 18.109375 28.8125 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 20.953125 51.171875 25.484375 53.578125 \n",
       "Q 30.03125 56 36.53125 56 \n",
       "Q 37.453125 56 38.578125 55.875 \n",
       "Q 39.703125 55.765625 41.0625 55.515625 \n",
       "z\n",
       "\" id=\"DejaVuSans-114\"/>\n",
       "      <path d=\"M 9.421875 54.6875 \n",
       "L 18.40625 54.6875 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 64.59375 \n",
       "L 9.421875 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-105\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798437 67.277031)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-81\"/>\n",
       "      <use x=\"78.710938\" xlink:href=\"#DejaVuSans-117\"/>\n",
       "      <use x=\"142.089844\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"203.613281\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"244.726562\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"272.509766\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"334.033203\" xlink:href=\"#DejaVuSans-115\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 34.240625 36.81 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 34.240625 36.81 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       "  <g id=\"axes_2\">\n",
       "   <g id=\"patch_7\">\n",
       "    <path clip-path=\"url(#p97c982ec0f)\" d=\"M 152.815625 88.74 \n",
       "L 152.815625 88.421484 \n",
       "L 152.815625 7.518516 \n",
       "L 152.815625 7.2 \n",
       "L 156.892625 7.2 \n",
       "L 156.892625 7.518516 \n",
       "L 156.892625 88.421484 \n",
       "L 156.892625 88.74 \n",
       "z\n",
       "\" style=\"fill:#ffffff;stroke:#ffffff;stroke-linejoin:miter;stroke-width:0.01;\"/>\n",
       "   </g>\n",
       "   <image height=\"82\" id=\"image9c70671d0b\" transform=\"scale(1 -1)translate(0 -82)\" width=\"4\" x=\"153\" xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAAQAAABSCAYAAABzJnWUAAAABHNCSVQICAgIfAhkiAAAAJ5JREFUOI2dkjsOQjEMBI307n9UKgoU/2jJ7JMMSZfReNeK8ujXs+3rXJZhO6giyAH0aJwAWSwctaNhSRAc8cXQ/zO0djRYexM6ZhwYUpt89R/2AGgxJNTn1QXwjwlwgBJjyaaSUYNRbNEMgmBoZsOI0SDI5EjICIwQQ2qbxn63K02MA2CDcbMHgDPDabwLYAngyGKojEiLGFK7P7HZBwTeyrmHgLkCAAAAAElFTkSuQmCC\" y=\"-6\"/>\n",
       "   <g id=\"matplotlib.axis_3\"/>\n",
       "   <g id=\"matplotlib.axis_4\">\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 3.5 0 \n",
       "\" id=\"me76c114e76\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#me76c114e76\" y=\"88.74\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-46\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(163.892625 92.539219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#me76c114e76\" y=\"56.124\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.2 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(163.892625 59.923219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#me76c114e76\" y=\"23.508\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.4 -->\n",
       "      <defs>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-52\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(163.892625 27.307219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-52\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_8\">\n",
       "    <path d=\"M 152.815625 88.74 \n",
       "L 152.815625 88.421484 \n",
       "L 152.815625 7.518516 \n",
       "L 152.815625 7.2 \n",
       "L 156.892625 7.2 \n",
       "L 156.892625 7.518516 \n",
       "L 156.892625 88.421484 \n",
       "L 156.892625 88.74 \n",
       "z\n",
       "\" style=\"fill:none;stroke:#000000;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p35918af0c6\">\n",
       "   <rect height=\"22.32\" width=\"111.6\" x=\"34.240625\" y=\"36.81\"/>\n",
       "  </clipPath>\n",
       "  <clipPath id=\"p97c982ec0f\">\n",
       "   <rect height=\"81.54\" width=\"4.077\" x=\"152.815625\" y=\"7.2\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 180x180 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
    "                 xlabel='Keys', ylabel='Queries')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 1, 20]) torch.Size([2, 10, 2])\n",
      "torch.Size([2, 1, 8]) torch.Size([2, 10, 8])\n",
      "torch.Size([2, 1, 1, 8]) torch.Size([2, 1, 10, 8])\n",
      "torch.Size([2, 1, 10, 8])\n",
      "torch.Size([2, 1, 10, 8])\n"
     ]
    }
   ],
   "source": [
    "q1, k1 = torch.ones((2, 1, 20)), torch.ones((2, 10, 2))\n",
    "print(q1.shape, k1.shape)\n",
    "W_q = nn.Linear(20, 8, bias=False)\n",
    "W_k = nn.Linear(2, 8, bias=False)\n",
    "q1, k1 = W_q(q1), W_k(k1)\n",
    "print(q1.shape, k1.shape)\n",
    "q1u, k1u = q1.unsqueeze(2), k1.unsqueeze(1)\n",
    "print(q1u.shape, k1u.shape)\n",
    "features = q1u + k1u\n",
    "print(features.shape)\n",
    "features = torch.tanh(features)\n",
    "print(features.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 10.3.3 缩放点积注意力"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DotProductAttention(nn.Module):\n",
    "    \"\"\"缩放点积注意力\"\"\"\n",
    "    def __init__(self, dropout, **kwargs):\n",
    "        super(DotProductAttention, self).__init__(**kwargs)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, queries, keys, values, valid_lens=None):\n",
    "        d = queries.shape[-1]\n",
    "        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)\n",
    "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
    "        return torch.bmm(self.dropout(self.attention_weights), values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 1.0486,  0.7426]],\n",
       " \n",
       "         [[ 0.7217, -0.7056]]]), tensor([[[1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.]],\n",
       " \n",
       "         [[1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.],\n",
       "          [1., 1.]]]), tensor([[[ 0.,  1.,  2.,  3.],\n",
       "          [ 4.,  5.,  6.,  7.],\n",
       "          [ 8.,  9., 10., 11.],\n",
       "          [12., 13., 14., 15.],\n",
       "          [16., 17., 18., 19.],\n",
       "          [20., 21., 22., 23.],\n",
       "          [24., 25., 26., 27.],\n",
       "          [28., 29., 30., 31.],\n",
       "          [32., 33., 34., 35.],\n",
       "          [36., 37., 38., 39.]],\n",
       " \n",
       "         [[ 0.,  1.,  2.,  3.],\n",
       "          [ 4.,  5.,  6.,  7.],\n",
       "          [ 8.,  9., 10., 11.],\n",
       "          [12., 13., 14., 15.],\n",
       "          [16., 17., 18., 19.],\n",
       "          [20., 21., 22., 23.],\n",
       "          [24., 25., 26., 27.],\n",
       "          [28., 29., 30., 31.],\n",
       "          [32., 33., 34., 35.],\n",
       "          [36., 37., 38., 39.]]]))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `keys` and `values` are not created in this cell; an earlier cell must have defined them.\n",
    "queries = torch.normal(0, 1, size=(2, 1, 2))\n",
    "queries, keys, values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],\n",
       "\n",
       "        [[10.0000, 11.0000, 12.0000, 13.0000]]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# eval() returns the module itself and puts it in evaluation mode, so the\n",
    "# Dropout layer does not randomly zero attention weights in this demo.\n",
    "attention = DotProductAttention(dropout=0.5).eval()\n",
    "attention(queries, keys, values, valid_lens=valid_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"101.818906pt\" version=\"1.1\" viewBox=\"0 0 186.99575 101.818906\" width=\"186.99575pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 101.818906 \n",
       "L 186.99575 101.818906 \n",
       "L 186.99575 0 \n",
       "L -0 0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "L 34.240625 36.81 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g clip-path=\"url(#p701e67519e)\">\n",
       "    <image height=\"23\" id=\"image3d3e63616f\" transform=\"scale(1 -1)translate(0 -23)\" width=\"112\" x=\"34.240625\" xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAHAAAAAXCAYAAADTEcupAAAABHNCSVQICAgIfAhkiAAAAHhJREFUaIHt2LEJgEAQBdF/mlqfLRhag12Y2ZwtmAlebiZyHAPzClgWhk223Mf2REmScV57r/DZ0HsB/WNAOAPCGRDOgHAGhDMgnAHhDAhnQDgDwpUlU5Nf6H6dLcbqxQuEMyCcAeEMCGdAOAPCGRDOgHAGhDMgXAXNiAfzAcIdmQAAAABJRU5ErkJggg==\" y=\"-36.13\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"m3196398196\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"39.820625\" xlink:href=\"#m3196398196\" y=\"59.13\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(36.639375 73.728437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"95.620625\" xlink:href=\"#m3196398196\" y=\"59.13\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-53\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(92.439375 73.728437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_3\">\n",
       "     <!-- Keys -->\n",
       "     <defs>\n",
       "      <path d=\"M 9.8125 72.90625 \n",
       "L 19.671875 72.90625 \n",
       "L 19.671875 42.09375 \n",
       "L 52.390625 72.90625 \n",
       "L 65.09375 72.90625 \n",
       "L 28.90625 38.921875 \n",
       "L 67.671875 0 \n",
       "L 54.6875 0 \n",
       "L 19.671875 35.109375 \n",
       "L 19.671875 0 \n",
       "L 9.8125 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-75\"/>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-101\"/>\n",
       "      <path d=\"M 32.171875 -5.078125 \n",
       "Q 28.375 -14.84375 24.75 -17.8125 \n",
       "Q 21.140625 -20.796875 15.09375 -20.796875 \n",
       "L 7.90625 -20.796875 \n",
       "L 7.90625 -13.28125 \n",
       "L 13.1875 -13.28125 \n",
       "Q 16.890625 -13.28125 18.9375 -11.515625 \n",
       "Q 21 -9.765625 23.484375 -3.21875 \n",
       "L 25.09375 0.875 \n",
       "L 2.984375 54.6875 \n",
       "L 12.5 54.6875 \n",
       "L 29.59375 11.921875 \n",
       "L 46.6875 54.6875 \n",
       "L 56.203125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-121\"/>\n",
       "      <path d=\"M 44.28125 53.078125 \n",
       "L 44.28125 44.578125 \n",
       "Q 40.484375 46.53125 36.375 47.5 \n",
       "Q 32.28125 48.484375 27.875 48.484375 \n",
       "Q 21.1875 48.484375 17.84375 46.4375 \n",
       "Q 14.5 44.390625 14.5 40.28125 \n",
       "Q 14.5 37.15625 16.890625 35.375 \n",
       "Q 19.28125 33.59375 26.515625 31.984375 \n",
       "L 29.59375 31.296875 \n",
       "Q 39.15625 29.25 43.1875 25.515625 \n",
       "Q 47.21875 21.78125 47.21875 15.09375 \n",
       "Q 47.21875 7.46875 41.1875 3.015625 \n",
       "Q 35.15625 -1.421875 24.609375 -1.421875 \n",
       "Q 20.21875 -1.421875 15.453125 -0.5625 \n",
       "Q 10.6875 0.296875 5.421875 2 \n",
       "L 5.421875 11.28125 \n",
       "Q 10.40625 8.6875 15.234375 7.390625 \n",
       "Q 20.0625 6.109375 24.8125 6.109375 \n",
       "Q 31.15625 6.109375 34.5625 8.28125 \n",
       "Q 37.984375 10.453125 37.984375 14.40625 \n",
       "Q 37.984375 18.0625 35.515625 20.015625 \n",
       "Q 33.0625 21.96875 24.703125 23.78125 \n",
       "L 21.578125 24.515625 \n",
       "Q 13.234375 26.265625 9.515625 29.90625 \n",
       "Q 5.8125 33.546875 5.8125 39.890625 \n",
       "Q 5.8125 47.609375 11.28125 51.796875 \n",
       "Q 16.75 56 26.8125 56 \n",
       "Q 31.78125 56 36.171875 55.265625 \n",
       "Q 40.578125 54.546875 44.28125 53.078125 \n",
       "z\n",
       "\" id=\"DejaVuSans-115\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(78.125 87.406562)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-75\"/>\n",
       "      <use x=\"65.498047\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"127.021484\" xlink:href=\"#DejaVuSans-121\"/>\n",
       "      <use x=\"186.201172\" xlink:href=\"#DejaVuSans-115\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m73ef145e1c\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m73ef145e1c\" y=\"42.39\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(20.878125 46.189219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m73ef145e1c\" y=\"53.55\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 1 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 57.349219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- Queries -->\n",
       "     <defs>\n",
       "      <path d=\"M 39.40625 66.21875 \n",
       "Q 28.65625 66.21875 22.328125 58.203125 \n",
       "Q 16.015625 50.203125 16.015625 36.375 \n",
       "Q 16.015625 22.609375 22.328125 14.59375 \n",
       "Q 28.65625 6.59375 39.40625 6.59375 \n",
       "Q 50.140625 6.59375 56.421875 14.59375 \n",
       "Q 62.703125 22.609375 62.703125 36.375 \n",
       "Q 62.703125 50.203125 56.421875 58.203125 \n",
       "Q 50.140625 66.21875 39.40625 66.21875 \n",
       "z\n",
       "M 53.21875 1.3125 \n",
       "L 66.21875 -12.890625 \n",
       "L 54.296875 -12.890625 \n",
       "L 43.5 -1.21875 \n",
       "Q 41.890625 -1.3125 41.03125 -1.359375 \n",
       "Q 40.1875 -1.421875 39.40625 -1.421875 \n",
       "Q 24.03125 -1.421875 14.8125 8.859375 \n",
       "Q 5.609375 19.140625 5.609375 36.375 \n",
       "Q 5.609375 53.65625 14.8125 63.9375 \n",
       "Q 24.03125 74.21875 39.40625 74.21875 \n",
       "Q 54.734375 74.21875 63.90625 63.9375 \n",
       "Q 73.09375 53.65625 73.09375 36.375 \n",
       "Q 73.09375 23.6875 67.984375 14.640625 \n",
       "Q 62.890625 5.609375 53.21875 1.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-81\"/>\n",
       "      <path d=\"M 8.5 21.578125 \n",
       "L 8.5 54.6875 \n",
       "L 17.484375 54.6875 \n",
       "L 17.484375 21.921875 \n",
       "Q 17.484375 14.15625 20.5 10.265625 \n",
       "Q 23.53125 6.390625 29.59375 6.390625 \n",
       "Q 36.859375 6.390625 41.078125 11.03125 \n",
       "Q 45.3125 15.671875 45.3125 23.6875 \n",
       "L 45.3125 54.6875 \n",
       "L 54.296875 54.6875 \n",
       "L 54.296875 0 \n",
       "L 45.3125 0 \n",
       "L 45.3125 8.40625 \n",
       "Q 42.046875 3.421875 37.71875 1 \n",
       "Q 33.40625 -1.421875 27.6875 -1.421875 \n",
       "Q 18.265625 -1.421875 13.375 4.4375 \n",
       "Q 8.5 10.296875 8.5 21.578125 \n",
       "z\n",
       "M 31.109375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-117\"/>\n",
       "      <path d=\"M 41.109375 46.296875 \n",
       "Q 39.59375 47.171875 37.8125 47.578125 \n",
       "Q 36.03125 48 33.890625 48 \n",
       "Q 26.265625 48 22.1875 43.046875 \n",
       "Q 18.109375 38.09375 18.109375 28.8125 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 20.953125 51.171875 25.484375 53.578125 \n",
       "Q 30.03125 56 36.53125 56 \n",
       "Q 37.453125 56 38.578125 55.875 \n",
       "Q 39.703125 55.765625 41.0625 55.515625 \n",
       "z\n",
       "\" id=\"DejaVuSans-114\"/>\n",
       "      <path d=\"M 9.421875 54.6875 \n",
       "L 18.40625 54.6875 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 64.59375 \n",
       "L 9.421875 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-105\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798437 67.277031)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-81\"/>\n",
       "      <use x=\"78.710938\" xlink:href=\"#DejaVuSans-117\"/>\n",
       "      <use x=\"142.089844\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"203.613281\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"244.726562\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"272.509766\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"334.033203\" xlink:href=\"#DejaVuSans-115\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 34.240625 36.81 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 34.240625 36.81 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       "  <g id=\"axes_2\">\n",
       "   <g id=\"patch_7\">\n",
       "    <path clip-path=\"url(#pcb79821179)\" d=\"M 152.815625 88.74 \n",
       "L 152.815625 88.421484 \n",
       "L 152.815625 7.518516 \n",
       "L 152.815625 7.2 \n",
       "L 156.892625 7.2 \n",
       "L 156.892625 7.518516 \n",
       "L 156.892625 88.421484 \n",
       "L 156.892625 88.74 \n",
       "z\n",
       "\" style=\"fill:#ffffff;stroke:#ffffff;stroke-linejoin:miter;stroke-width:0.01;\"/>\n",
       "   </g>\n",
       "   <image height=\"82\" id=\"image04009f93df\" transform=\"scale(1 -1)translate(0 -82)\" width=\"4\" x=\"153\" xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAAQAAABSCAYAAABzJnWUAAAABHNCSVQICAgIfAhkiAAAAJ5JREFUOI2dkjsOQjEMBI307n9UKgoU/2jJ7JMMSZfReNeK8ujXs+3rXJZhO6giyAH0aJwAWSwctaNhSRAc8cXQ/zO0djRYexM6ZhwYUpt89R/2AGgxJNTn1QXwjwlwgBJjyaaSUYNRbNEMgmBoZsOI0SDI5EjICIwQQ2qbxn63K02MA2CDcbMHgDPDabwLYAngyGKojEiLGFK7P7HZBwTeyrmHgLkCAAAAAElFTkSuQmCC\" y=\"-6\"/>\n",
       "   <g id=\"matplotlib.axis_3\"/>\n",
       "   <g id=\"matplotlib.axis_4\">\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 3.5 0 \n",
       "\" id=\"m7f82581462\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#m7f82581462\" y=\"88.74\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-46\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(163.892625 92.539219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#m7f82581462\" y=\"56.124\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.2 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(163.892625 59.923219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#m7f82581462\" y=\"23.508\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.4 -->\n",
       "      <defs>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-52\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(163.892625 27.307219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-52\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_8\">\n",
       "    <path d=\"M 152.815625 88.74 \n",
       "L 152.815625 88.421484 \n",
       "L 152.815625 7.518516 \n",
       "L 152.815625 7.2 \n",
       "L 156.892625 7.2 \n",
       "L 156.892625 7.518516 \n",
       "L 156.892625 88.421484 \n",
       "L 156.892625 88.74 \n",
       "z\n",
       "\" style=\"fill:none;stroke:#000000;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p701e67519e\">\n",
       "   <rect height=\"22.32\" width=\"111.6\" x=\"34.240625\" y=\"36.81\"/>\n",
       "  </clipPath>\n",
       "  <clipPath id=\"pcb79821179\">\n",
       "   <rect height=\"81.54\" width=\"4.077\" x=\"152.815625\" y=\"7.2\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 180x180 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Reshape the stored weights to 4-D before plotting; the leading (1, 1) is\n",
    "# presumably the subplot grid expected by d2l.show_heatmaps -- confirm.\n",
    "d2l.show_heatmaps(\n",
    "    attention.attention_weights.reshape(1, 1, 2, 10),\n",
    "    xlabel='Keys',\n",
    "    ylabel='Queries',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
